{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/main/pytorch/notebooks/Serving%20PyTorch%20Models%20with%20CMLE%20%20Custom%20Prediction%20Code.ipynb\"\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"> Run in Colab\n",
    "    </a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://github.com/GoogleCloudPlatform/cloudml-samples/blob/main/pytorch/notebooks/Serving%20PyTorch%20Models%20with%20CMLE%20%20Custom%20Prediction%20Code.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\">\n",
    "      View on GitHub\n",
    "    </a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Yy8l-95JkRx0"
   },
   "source": [
    "# Overview\n",
    "\n",
    "AI Platform Online Prediction now supports custom python code in to apply custom prediction routines, including custom (stateful) pre/post processing, and/or models not created by the standard supported frameworks (TensorFlow, Keras, Scikit-learn, XGBoost).\n",
    "\n",
    "### Dataset\n",
    "\n",
    "We use the [Iris dataset](https://archive.ics.uci.edu/ml/datasets/Iris)\n",
    "\n",
    "### Objective\n",
    "\n",
    "In this notebook, we show how to deploy a model created by [PyTorch](https://pytorch.org/) using AI Platform  Custom Prediction Code using Iris dataset for a multi-class classification problem.\n",
    "\n",
    "### Costs \n",
    "\n",
    "This tutorial uses billable components of Google Cloud Platform (GCP):\n",
    "\n",
    "* Cloud AI Platform\n",
    "* Cloud Storage\n",
    "\n",
    "Learn about [Cloud AI Platform\n",
    "pricing](https://cloud.google.com/ml-engine/docs/pricing) and [Cloud Storage\n",
    "pricing](https://cloud.google.com/storage/pricing), and use the [Pricing\n",
    "Calculator](https://cloud.google.com/products/calculator/)\n",
    "to generate a cost estimate based on your projected usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "D5AirRX0kRx2"
   },
   "source": [
    "### Set up your local development environment\n",
    "\n",
    "**If you are using Colab or AI Platform Notebooks**, your environment already meets\n",
    "all the requirements to run this notebook. You can skip this step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Otherwise**, make sure your environment meets this notebook's requirements.\n",
    "You need the following:\n",
    "\n",
    "* The Google Cloud SDK\n",
    "* Git\n",
    "* Python 3\n",
    "* virtualenv\n",
    "* Jupyter notebook running in a virtual environment with Python 3\n",
    "\n",
    "The Google Cloud guide to [Setting up a Python development\n",
    "environment](https://cloud.google.com/python/setup) and the [Jupyter\n",
    "installation guide](https://jupyter.org/install) provide detailed instructions\n",
    "for meeting these requirements. The following steps provide a condensed set of\n",
    "instructions:\n",
    "\n",
    "1. [Install and initialize the Cloud SDK.](https://cloud.google.com/sdk/docs/)\n",
    "\n",
    "2. [Install Python 3.](https://cloud.google.com/python/setup#installing_python)\n",
    "\n",
    "3. [Install\n",
    "   virtualenv](https://cloud.google.com/python/setup#installing_and_using_virtualenv)\n",
    "   and create a virtual environment that uses Python 3.\n",
    "\n",
    "4. Activate that environment and run `pip install jupyter` in a shell to install\n",
    "   Jupyter.\n",
    "\n",
    "5. Run `jupyter notebook` in a shell to launch Jupyter.\n",
    "\n",
    "6. Open this notebook in the Jupyter Notebook Dashboard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up your GCP project\n",
    "\n",
    "**The following steps are required, regardless of your notebook environment.**\n",
    "\n",
    "1. [Select or create a GCP project.](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.\n",
    "\n",
    "2. [Make sure that billing is enabled for your project.](https://cloud.google.com/billing/docs/how-to/modify-project)\n",
    "\n",
    "3. [Enable the AI Platform APIs and Compute Engine APIs.](https://console.cloud.google.com/flows/enableapi?apiid=ml.googleapis.com,compute_component)\n",
    "\n",
    "4. Enter your project ID in the cell below. Then run the  cell to make sure the\n",
    "Cloud SDK uses the right project for all the commands in this notebook.\n",
    "\n",
    "**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Authenticate your GCP account\n",
    "\n",
    "**If you are using AI Platform Notebooks**, your environment is already\n",
    "authenticated. Skip this step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**If you are using Colab**, run the cell below and follow the instructions\n",
    "when prompted to authenticate your account via oAuth.\n",
    "\n",
    "**Otherwise**, follow these steps:\n",
    "\n",
    "1. In the GCP Console, go to the [**Create service account key**\n",
    "   page](https://console.cloud.google.com/apis/credentials/serviceaccountkey).\n",
    "\n",
    "2. From the **Service account** drop-down list, select **New service account**.\n",
    "\n",
    "3. In the **Service account name** field, enter a name.\n",
    "\n",
    "4. From the **Role** drop-down list, select\n",
    "   **Machine Learning Engine > AI Platform Admin** and\n",
    "   **Storage > Storage Object Admin**.\n",
    "\n",
    "5. Click *Create*. A JSON file that contains your key downloads to your\n",
    "local environment.\n",
    "\n",
    "6. Enter the path to your service account key as the\n",
    "`GOOGLE_APPLICATION_CREDENTIALS` variable in the cell below and run the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# If you are running this notebook in Colab, run this cell and follow the\n",
    "# instructions to authenticate your GCP account. This provides access to your\n",
    "# Cloud Storage bucket and lets you submit training jobs and prediction\n",
    "# requests.\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "  from google.colab import auth as google_auth\n",
    "  google_auth.authenticate_user()\n",
    "\n",
    "# If you are running this notebook locally, replace the string below with the\n",
    "# path to your service account key and run this cell to authenticate your GCP\n",
    "# account.\n",
    "else:\n",
    "  %env GOOGLE_APPLICATION_CREDENTIALS ''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PIP Install Packages and dependencies\n",
    "\n",
    "Before we start let's install pytorch and gcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 5726,
     "status": "ok",
     "timestamp": 1538245666639,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "b_4YHbyCkRx2",
    "outputId": "7caa60de-641d-4cb4-aab1-38ba0e9db975"
   },
   "outputs": [],
   "source": [
    "!pip install torch --user"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5pHlrb0slo3i"
   },
   "source": [
    "If you are running this notebook in Colab, run the following cell to authenticate your Google Cloud Platform user account"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 5530,
     "status": "ok",
     "timestamp": 1538245790469,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "lQq3X_e_kRx9",
    "outputId": "eb142794-f9c9-4baf-e584-4629c7a90c72"
   },
   "outputs": [],
   "source": [
    "PROJECT = '' # TODO (Set to your GCP Project name)\n",
    "BUCKET = '' # TODO (Set to your GCS Bucket name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud config set project {PROJECT}\n",
    "!gcloud config get-value project"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zgIaQaRakRyC"
   },
   "source": [
    "## 3. Download iris data\n",
    "In this example, we want to build a classifier for the simple [iris dataset](https://archive.ics.uci.edu/ml/datasets/iris). So first, we download the data csv file locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5zu_lz-RkRyD"
   },
   "outputs": [],
   "source": [
    "!mkdir data\n",
    "!mkdir models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOCAL_DATA_DIR = \"data/iris.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 985,
     "status": "ok",
     "timestamp": 1538245803249,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "CNz7zOYYkRyG",
    "outputId": "fe3856f7-3b9f-46fa-eb5f-a86d122b3a8c"
   },
   "outputs": [],
   "source": [
    "from urllib.request import urlretrieve\n",
    "\n",
    "urlretrieve(\"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\", LOCAL_DATA_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fYg0djCukRyK"
   },
   "source": [
    "# Part 1: Build a PyTorch NN Classifier\n",
    "\n",
    "Make sure that pytorch package is [installed](https://pytorch.org/get-started/locally/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 7065,
     "status": "ok",
     "timestamp": 1538245839403,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "oB3betHBkRyL",
    "outputId": "28de856b-1663-47c1-a22c-cebafdd13a36"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "\n",
    "print('PyTorch Version: {}'.format(torch.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qr0w3nDvkRyO"
   },
   "source": [
    "## 1. Load Data \n",
    "In this step, we are going to:\n",
    "1. Load the data to Pandas Dataframe.\n",
    "2. Convert the class feature (species) from string to a numeric indicator.\n",
    "3. Split the Dataframe into input feature (xtrain) and target feature (ytrain)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 509,
     "status": "ok",
     "timestamp": 1538245841842,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "lk810wj8kRyP",
    "outputId": "a6788660-f35a-4324-b34f-60e5c28efc7b"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "CLASS_VOCAB = ['setosa', 'versicolor', 'virginica']\n",
    "\n",
    "datatrain = pd.read_csv(LOCAL_DATA_DIR, names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'])\n",
    "\n",
    "#change string value to numeric\n",
    "datatrain.loc[datatrain['species']=='Iris-setosa', 'species']=0\n",
    "datatrain.loc[datatrain['species']=='Iris-versicolor', 'species']=1\n",
    "datatrain.loc[datatrain['species']=='Iris-virginica', 'species']=2\n",
    "datatrain = datatrain.apply(pd.to_numeric)\n",
    "\n",
    "#change dataframe to array\n",
    "datatrain_array = datatrain.as_matrix()\n",
    "\n",
    "#split x and y (feature and target)\n",
    "xtrain = datatrain_array[:,:4]\n",
    "ytrain = datatrain_array[:,4]\n",
    "\n",
    "input_features = xtrain.shape[1]\n",
    "num_classes = len(CLASS_VOCAB)\n",
    "\n",
    "print('Records loaded: {}'.format(len(xtrain)))\n",
    "print('Number of input features: {}'.format(input_features))\n",
    "print('Number of classes: {}'.format(num_classes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eCYbi-h7kRyU"
   },
   "source": [
    "## 2. Set model parameters\n",
    "You can try different values for **hidden_units** or **learning_rate**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qvyvyolgkRyU"
   },
   "outputs": [],
   "source": [
    "HIDDEN_UNITS = 10\n",
    "LEARNING_RATE = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JzAVqhvCkRyX"
   },
   "source": [
    "## 3. Define the PyTorch NN model\n",
    "\n",
    "Here, we build a a neural network with one hidden layer, and a Softmax output layer for classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4BP-U3gPkRyX"
   },
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(input_features, HIDDEN_UNITS),\n",
    "    torch.nn.Sigmoid(),\n",
    "    torch.nn.Linear(HIDDEN_UNITS, num_classes),\n",
    "    torch.nn.Softmax()\n",
    ")\n",
    "\n",
    "loss_metric = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(),lr=LEARNING_RATE)"
   ]
  },
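  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the Softmax output layer does, the following plain-Python sketch turns raw class scores into probabilities. It is not part of the original notebook: the helper names `softmax` and `argmax` and the sample scores are made up for this example. Because softmax preserves the ordering of its inputs, the predicted class is the same whether you take the arg max of the raw scores or of the probabilities.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def softmax(scores):\n",
    "    # Subtract the max for numerical stability; the result is unchanged.\n",
    "    top = max(scores)\n",
    "    exps = [math.exp(s - top) for s in scores]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def argmax(values):\n",
    "    # Index of the largest value.\n",
    "    return max(range(len(values)), key=lambda i: values[i])\n",
    "\n",
    "\n",
    "scores = [1.2, -0.3, 3.1]  # hypothetical raw outputs of the last Linear layer\n",
    "probs = softmax(scores)\n",
    "print(probs)\n",
    "print(argmax(scores) == argmax(probs))\n",
    "```\n",
    "\n",
    "The probabilities sum to 1, which is what makes them convenient to return from a prediction service."
   ]
  },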
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bM79nn7HkRyZ"
   },
   "source": [
    "## 4. Train the model\n",
    "We are going to train the model for **num_epoch** epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 258
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6461,
     "status": "ok",
     "timestamp": 1538245853213,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "ElikXuyokRya",
    "outputId": "c6226007-47fb-4ea7-faee-fdce8db0049a"
   },
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 10000\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    x = Variable(torch.Tensor(xtrain).float())\n",
    "    y = Variable(torch.Tensor(ytrain).long())\n",
    "    optimizer.zero_grad()\n",
    "    y_pred = model(x)\n",
    "    loss = loss_metric(y_pred, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if (epoch) % 1000 == 0:\n",
    "        print('Epoch [{}/{}] Loss: {}'.format(epoch+1, NUM_EPOCHS, round(loss.item(),3)))\n",
    "        \n",
    "print('Epoch [{}/{}] Loss: {}'.format(epoch+1, NUM_EPOCHS, round(loss.item(),3)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pzE7_e_kkRyc"
   },
   "source": [
    "## 5. Save and load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TVGH-GS9kRye"
   },
   "outputs": [],
   "source": [
    "LOCAL_MODEL_DIR = \"models/model.pt\"\n",
    "\n",
    "\n",
    "torch.save(model, LOCAL_MODEL_DIR)\n",
    "iris_classifier = torch.load(LOCAL_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nbb4S2s5kRyh"
   },
   "source": [
    "## 6. Test the loaded model for predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3E_i7qpVkRyi"
   },
   "outputs": [],
   "source": [
    "def predict_class(instances):\n",
    "    instances = torch.Tensor(instances)\n",
    "    output = iris_classifier(instances)\n",
    "    _ , predicted = torch.max(output, 1)\n",
    "    return predicted"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7Jdl1MlVkRyn"
   },
   "source": [
    "Get predictions for the first 5 instances in the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 638,
     "status": "ok",
     "timestamp": 1538245859222,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "HEyiYiRckRyn",
    "outputId": "66e52ccc-dac9-47e5-e47c-83572c202557"
   },
   "outputs": [],
   "source": [
    "predicted = predict_class(xtrain[0:5])\n",
    "print([CLASS_VOCAB[class_index] for class_index in predicted])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the classification accuracy on the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "accuracy = round(sum(np.array(predict_class(xtrain)) == ytrain)/float(len(ytrain))*100,2)\n",
    "print('Classification accuracy: {} %'.format(accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vC4N_80wkRyr"
   },
   "source": [
    "## 7. Upload trained model to Cloud Storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8205,
     "status": "ok",
     "timestamp": 1538245870922,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "Ue5XadCLkRys",
    "outputId": "27d285d6-3efd-44bb-ae7d-b1910a46dfbe"
   },
   "outputs": [],
   "source": [
    "GCS_MODEL_DIR='models/pytorch/iris_classifier/'\n",
    "\n",
    "!gsutil -m cp -r {LOCAL_MODEL_DIR} gs://{BUCKET}/{GCS_MODEL_DIR}\n",
    "!gsutil ls gs://{BUCKET}/{GCS_MODEL_DIR}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aQP0lGikkRyw"
   },
   "source": [
    "# Part 2: Prepare the Custom Prediction Package\n",
    "\n",
    "1. Implement a model **custom class** for pre/post processing, as well as loading and using your model for prediction.\n",
    "2. Prepare yout **setup.py** file, to include all the modules and packages you need in your custome model class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KQDLnysXkRyw"
   },
   "source": [
    "## 1. Create the custom model class\n",
    "In the **from_path**, you load the pytorch model that you uploaded to GCS. Then in the **predict** method, you use it for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 625,
     "status": "ok",
     "timestamp": 1538245874513,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "fg6f8pP8kRyx",
    "outputId": "1a9feaf0-e515-4f3a-fe13-f96d4a1310a8"
   },
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "\n",
    "import os\n",
    "import pandas as pd\n",
    "from google.cloud import storage\n",
    "import torch\n",
    "\n",
    "class PyTorchIrisClassifier(object):\n",
    "    \n",
    "    def __init__(self, model):\n",
    "        self._model = model\n",
    "        self.class_vocab = ['setosa', 'versicolor', 'virginica']\n",
    "        \n",
    "    @classmethod\n",
    "    def from_path(cls, model_dir):\n",
    "        model_file = os.path.join(model_dir,'model.pt')\n",
    "        model = torch.load(model_file)    \n",
    "        return cls(model)\n",
    "\n",
    "    def predict(self, instances, **kwargs):\n",
    "        data = pd.DataFrame(instances).as_matrix()\n",
    "        inputs = torch.Tensor(data)\n",
    "        outputs = self._model(inputs)\n",
    "        _ , predicted = torch.max(outputs, 1)\n",
    "        return [self.class_vocab[class_index] for class_index in predicted]"
   ]
  },
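  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shape of the interface easier to see, here is a dependency-free sketch. It is a hypothetical stand-in, not the deployed class: the name `ArgmaxPredictor` and its scoring logic are assumptions made for illustration. It follows the same pattern as the class above, with a `from_path` class method that builds the object and a `predict` method that takes a list of instances and returns a list of labels.\n",
    "\n",
    "```python\n",
    "class ArgmaxPredictor(object):\n",
    "    # Hypothetical stand-in with the same from_path/predict shape as\n",
    "    # PyTorchIrisClassifier, but without a model file.\n",
    "\n",
    "    def __init__(self, class_vocab):\n",
    "        self.class_vocab = class_vocab\n",
    "\n",
    "    @classmethod\n",
    "    def from_path(cls, model_dir):\n",
    "        # A real predictor would load its model from model_dir here.\n",
    "        return cls(['setosa', 'versicolor', 'virginica'])\n",
    "\n",
    "    def predict(self, instances, **kwargs):\n",
    "        # For this sketch, each instance is already a list of class scores.\n",
    "        labels = []\n",
    "        for scores in instances:\n",
    "            best = max(range(len(scores)), key=lambda i: scores[i])\n",
    "            labels.append(self.class_vocab[best])\n",
    "        return labels\n",
    "\n",
    "\n",
    "predictor = ArgmaxPredictor.from_path('unused-dir')\n",
    "print(predictor.predict([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]))\n",
    "```\n",
    "\n",
    "Exercising a class like this locally before packaging it can catch interface mistakes early, such as returning values that cannot be serialized to JSON."
   ]
  },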
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y22oukOdkRy1"
   },
   "source": [
    "## 2. Create a setup.py module\n",
    "Do not include **pytorch** as a required package, as well as the **model.py** file that includes your custom model class. We will include it when creating the model below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 560,
     "status": "ok",
     "timestamp": 1538245877594,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "QjOxQt2lkRy2",
    "outputId": "3d06578d-d47b-4f30-d927-9225adb4e762"
   },
   "outputs": [],
   "source": [
    "%%writefile setup.py\n",
    "\n",
    "from setuptools import setup\n",
    "\n",
    "REQUIRED_PACKAGES = []\n",
    "\n",
    "setup(\n",
    "    name=\"iris-custom-model\",\n",
    "    version=\"0.1\",\n",
    "    scripts=[\"model.py\"],\n",
    "    install_requires=REQUIRED_PACKAGES\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AMopNrL0kRy6"
   },
   "source": [
    "## 3. Create the package \n",
    "\n",
    "This will create a .tar.gz package under /dist directory. The name of the package will be (name)-(version).tar.gz where (name) and (version) are the ones specified in the setup.py."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 544
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3243,
     "status": "ok",
     "timestamp": 1538245884136,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "MhDjxWNKkRy7",
    "outputId": "209288b7-cece-4a06-8595-4764921691e0"
   },
   "outputs": [],
   "source": [
    "!python setup.py sdist"
   ]
  },
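  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Depending on the setuptools version, the archive file name may be normalized (for example, with underscores instead of hyphens), so it can help to look up the file that was actually built rather than assume its name. The helper below is a small sketch for that; `find_sdist` is a hypothetical name, and the default `dist` directory is the one produced by the command above.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "\n",
    "\n",
    "def find_sdist(dist_dir='dist'):\n",
    "    # Return the most recently modified .tar.gz in dist_dir, or None if there is none.\n",
    "    candidates = glob.glob(os.path.join(dist_dir, '*.tar.gz'))\n",
    "    if not candidates:\n",
    "        return None\n",
    "    return max(candidates, key=os.path.getmtime)\n",
    "\n",
    "\n",
    "print(find_sdist())\n",
    "```\n",
    "\n",
    "If the printed name differs from `iris-custom-model-0.1.tar.gz`, adjust the file name used in the upload step accordingly."
   ]
  },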
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y0HtInoxkRy_"
   },
   "source": [
    "## 4. Uploaded the package to GCS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 7247,
     "status": "ok",
     "timestamp": 1538245946830,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "1rY8YgOEkRzA",
    "outputId": "bc6db003-2d0c-4470-eaf9-78e51e23109d"
   },
   "outputs": [],
   "source": [
    "GCS_PACKAGE_URI='models/pytorch/packages/iris-custom-model-0.1.tar.gz'\n",
    "\n",
    "!gsutil cp ./dist/iris-custom-model-0.1.tar.gz gs://{BUCKET}/{GCS_PACKAGE_URI}\n",
    "!gsutil ls gs://{BUCKET}/{GCS_PACKAGE_URI}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GhF3pYQTkRzE"
   },
   "source": [
    "# Part 3: Deploy the Model to AI Platform for Online Predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-40ZkeakkRzF"
   },
   "source": [
    "## 1. Create AI Platform model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8718,
     "status": "ok",
     "timestamp": 1538245958274,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "muG2T57skRzG",
    "outputId": "d6747346-1f2b-4382-a8ca-8e9535a6c5be"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME='torch_iris_classifier'\n",
    "REGION = 'us-central1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can uncomment to enable logging\n",
    "!gcloud ai-platform models create {MODEL_NAME} --regions {REGION} #--enable-logging --enable-console-logging\n",
    "!gcloud ai-platform models list | grep 'torch'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bNbqAHgFkRzI"
   },
   "source": [
    "## 2. Create AI Platform model version\n",
    "\n",
    "Once you have your custom package ready, you can specify this as an argument when creating a version resource. Note that you need to provide the path to your package (as package-uris) and also the class name that contains your custom predict method (as model-class)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pytorch compatible packages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to use compiled packages compatible with Cloud AI Platform Package information here\n",
    "\n",
    "This bucket containers compiled packages for PyTorch that are compatible with Cloud AI Platform prediction. The files are mirroed from the official builds at https://download.pytorch.org/whl/cpu/torch_stable.html\n",
    "\n",
    "In order to deploy a PyTorch model on Cloud AI Platform Online Predictions, you must add one of these packages to the packageURIs field on the version you deploy. Pick the package matching your Python and PyTorch version. The package names follow this template:\n",
    "\n",
    "Package name = torch-{TORCH_VERSION_NUMBER}-{PYTHON_VERSION}-linux_x86_64.whl where PYTHON_VERSION = cp35-cp35m for Python 3 with runtime versions < 1.15, cp37-cp37m for Python 3 with runtime versions >= 1.15\n",
    "\n",
    "Use cp27-cp27mu for Python 2.\n",
    "\n",
    "For example, if I were to deploy a PyTorch model based on PyTorch 1.1.0 and Python 3, my gcloud command would look like:\n",
    "\n",
    "gcloud beta ai-platform versions create {VERSION_NAME} --model {MODEL_NAME} \\\n",
    "...\n",
    "--package-uris=gs://{MY_PACKAGE_BUCKET}/my_package-0.1.tar.gz,gs://cloud-ai-pytorch/torch-1.1.0-cp35-cp35m-linux_x86_64.whl"
   ]
  },
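  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The naming template above can be expressed as a small helper. This is only a sketch: `torch_wheel_name` is a hypothetical function, and it does not check that the resulting file actually exists in the bucket.\n",
    "\n",
    "```python\n",
    "def torch_wheel_name(torch_version, runtime_version, python_major=3):\n",
    "    # Fill in the package name template for a given PyTorch and runtime version.\n",
    "    if python_major == 2:\n",
    "        python_tag = 'cp27-cp27mu'\n",
    "    else:\n",
    "        # Compare versions numerically, so that 1.9 < 1.15.\n",
    "        major, minor = (int(part) for part in runtime_version.split('.')[:2])\n",
    "        python_tag = 'cp37-cp37m' if (major, minor) >= (1, 15) else 'cp35-cp35m'\n",
    "    return 'torch-{}-{}-linux_x86_64.whl'.format(torch_version, python_tag)\n",
    "\n",
    "\n",
    "print(torch_wheel_name('1.1.0', '1.14'))\n",
    "print(torch_wheel_name('1.3.1+cpu', '1.15'))\n",
    "```\n",
    "\n",
    "The second call reproduces the file name used in the deployment command in this notebook. The version string is passed through as-is, including any suffix such as `+cpu`."
   ]
  },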
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gpXx-nFLkRzJ"
   },
   "outputs": [],
   "source": [
    "MODEL_VERSION='v3'\n",
    "RUNTIME_VERSION='1.15'\n",
    "MODEL_CLASS='model.PyTorchIrisClassifier'\n",
    "\n",
    "!gcloud beta ai-platform versions create {MODEL_VERSION} --model={MODEL_NAME} \\\n",
    "            --origin=gs://{BUCKET}/{GCS_MODEL_DIR} \\\n",
    "            --python-version=3.7 \\\n",
    "            --runtime-version={RUNTIME_VERSION} \\\n",
    "            --machine-type=mls1-c4-m4 \\\n",
    "            --package-uris=gs://{BUCKET}/{GCS_PACKAGE_URI},gs://cloud-ai-pytorch/torch-1.3.1+cpu-cp37-cp37m-linux_x86_64.whl \\\n",
    "            --prediction-class={MODEL_CLASS}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3411,
     "status": "ok",
     "timestamp": 1538246126033,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "naG7FtQokRzM",
    "outputId": "2df0240e-1445-47e8-9a41-627dfacab273"
   },
   "outputs": [],
   "source": [
    "!gcloud ai-platform versions list --model {MODEL_NAME}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TQz-JaqpkRzP"
   },
   "source": [
    "# Part 4: AI Platform Online Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Wdx5nTxjkRzR"
   },
   "outputs": [],
   "source": [
    "from googleapiclient import discovery\n",
    "from oauth2client.client import GoogleCredentials\n",
    "\n",
    "credentials = GoogleCredentials.get_application_default()\n",
    "api = discovery.build('ml', 'v1', credentials=credentials,\n",
    "                      discoveryServiceUrl='https://storage.googleapis.com/cloud-ml/discovery/ml_v1_discovery.json')\n",
    "\n",
    "\n",
    "def estimate(project, model_name, version, instances):\n",
    "    request_data = {'instances': instances}\n",
    "    model_url = 'projects/{}/models/{}/versions/{}'.format(project, model_name, version)\n",
    "    response = api.projects().predict(body=request_data, name=model_url).execute()\n",
    "\n",
    "    #print response\n",
    "    \n",
    "    predictions = response[\"predictions\"]\n",
    "    return predictions"
   ]
  },
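  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what is sent to the service, the sketch below builds the same resource name and request body as the `estimate` function, without calling the API. `build_predict_request` is a hypothetical helper and `my-project` is a placeholder project ID.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def build_predict_request(project, model_name, version, instances):\n",
    "    # Return the version resource name and the request body for a predict call.\n",
    "    name = 'projects/{}/models/{}/versions/{}'.format(project, model_name, version)\n",
    "    body = {'instances': instances}\n",
    "    return name, body\n",
    "\n",
    "\n",
    "name, body = build_predict_request(\n",
    "    'my-project', 'torch_iris_classifier', 'v3', [[6.8, 2.8, 4.8, 1.4]])\n",
    "print(name)\n",
    "print(json.dumps(body))\n",
    "```\n",
    "\n",
    "Each entry in `instances` is passed to the `predict` method of the custom prediction class, so its shape must match what that method expects: here, a list of four feature values."
   ]
  },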
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 869,
     "status": "ok",
     "timestamp": 1538246147364,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "ETKlF4yHkRzT",
    "outputId": "f1089ffd-c6c0-4db9-cf64-02bf414a5701"
   },
   "outputs": [],
   "source": [
    "instances = [\n",
    "    [6.8, 2.8, 4.8, 1.4],\n",
    "    [6. , 3.4, 4.5, 1.6]\n",
    "]\n",
    "\n",
    "predictions = estimate(instances=instances\n",
    "                     ,project=PROJECT\n",
    "                     ,model_name=MODEL_NAME\n",
    "                     ,version=MODEL_VERSION)\n",
    "\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nk8oZ2zPkRzW"
   },
   "source": [
    "# Questions? Feedback?\n",
    "Feel free to send us an email (cloudml-feedback@google.com) if you run into any issues or have any questions/feedback!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Serving PyTorch Models with AI Platform  Custom Prediction Code.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
